package com.hyjiacan.tools.apps.dbm.handlers

import com.hyjiacan.tools.apps.dbm.models.DbmRequestForm
import com.hyjiacan.tools.utils.Json
import com.hyjiacan.tools.apps.dbm.utils.Misc
import com.hyjiacan.tools.apps.dbm.utils.MySQL
import com.hyjiacan.tools.http.entities.QueryResult
import com.hyjiacan.tools.http.handlers.AjaxHandler
import com.hyjiacan.tools.http.models.RequestContext
import java.io.BufferedInputStream
import java.io.File
import java.io.InputStream
import java.io.PrintWriter
import java.nio.file.Paths
import java.sql.Connection
import java.sql.ResultSet
import java.sql.SQLException
import java.sql.Types
import java.time.LocalDateTime
import java.time.format.DateTimeFormatter
import java.util.*

/**
 * AJAX endpoints of the database manager: connection probing, schema inspection, statement
 * execution, SQL dump export and Markdown documentation export.
 *
 * Exported files are written under [Misc.getTmpPath] and registered in the session under a random
 * UUID. The client then fetches them once through `/api/v1/database/export/download?filename=<uuid>`.
 */
class DatabaseHandler : AjaxHandler() {

    /** Routes GET requests; only the export download endpoint is supported. */
    override fun get(request: RequestContext): Any = when (request.path) {
        "/api/v1/database/export/download" -> downloadExport(request)
        else -> throw Exception("Path not found")
    }

    /** Routes POST requests. The form is parsed before routing, as every POST endpoint needs it. */
    override fun post(request: RequestContext): Any {
        val form = DbmRequestForm(request.form)
        return when (request.path) {
            "/api/v1/database/connect" -> connect(form)
            "/api/v1/database/info" -> info(form)
            "/api/v1/database/table/columns" -> tableColumns(form)
            "/api/v1/database/table/primary_k" -> primaryKey(form)
            "/api/v1/database/table/definition" -> tableDefinition(form)
            "/api/v1/database/execute" -> execute(form)
            "/api/v1/database/export" -> export(request, form)
            "/api/v1/database/exportDoc" -> exportDoc(request, form)
            else -> throw Exception("Path not found")
        }
    }

    override fun put(request: RequestContext): Any {
        TODO("Not yet implemented")
    }

    override fun delete(request: RequestContext): Any {
        TODO("Not yet implemented")
    }

    /** Checks the connection options by running `show databases`. */
    @Throws(SQLException::class)
    fun connect(params: DbmRequestForm): Any {
        return MySQL.execute(params.options, "show databases", 0, 0)
    }

    /**
     * Runs [INFORMATION_STATEMENTS] (tables, views, routines, events, triggers, variables) for the
     * database named in the connection options.
     */
    @Throws(Exception::class)
    fun info(params: DbmRequestForm): Any {
        val option = Misc.decodeOption(params.options)
        // Plain (non-regex) replacement: a '$' or '\' in the name must not be treated as a group reference.
        val statement = INFORMATION_STATEMENTS.replace("#database#", escapeSqlString(option.database))
        return MySQL.execute(params.options, statement, 0, 0)
    }

    /**
     * Lists the columns of `params.table` in ordinal order.
     *
     * `params.table` may be `table` (uses the database from the connection options) or `database.table`.
     */
    @Throws(Exception::class)
    fun tableColumns(params: DbmRequestForm): Any {
        val option = Misc.decodeOption(params.options)
        val (database, table) = splitQualifiedName(params.table, option.database)

        val statement =
            "select COLUMN_NAME, IS_NULLABLE, COLUMN_DEFAULT, COLUMN_TYPE, COLUMN_KEY, EXTRA, COLUMN_COMMENT from information_schema.columns\n" +
                "        where table_schema='${escapeSqlString(database)}' and table_name='${escapeSqlString(table)}' order by ORDINAL_POSITION"
        return MySQL.execute(params.options, statement, 0, 0)
    }

    /**
     * Returns the primary-key column names of `params.table` as a single comma-separated value in a
     * column labelled `key`.
     *
     * `params.table` may be `table` or `database.table`, as in [tableColumns].
     */
    @Throws(Exception::class)
    fun primaryKey(params: DbmRequestForm): Any {
        val option = Misc.decodeOption(params.options)
        val (database, table) = splitQualifiedName(params.table, option.database)

        // KEY is a reserved word in MySQL, so the alias must be backtick-quoted.
        val statement =
            "select GROUP_CONCAT(COLUMN_NAME ORDER BY ORDINAL_POSITION) as `key` from information_schema.key_column_usage\n" +
                "        where table_schema='${escapeSqlString(database)}' and table_name='${escapeSqlString(table)}' and constraint_name='PRIMARY'"
        return MySQL.execute(params.options, statement, 0, 0)
    }

    /** Returns the `show create table` output for `params.table` (quoted as a single identifier). */
    @Throws(Exception::class)
    fun tableDefinition(params: DbmRequestForm): Any {
        val statement = "show create table ${quoteIdentifier(params.table)}"
        return MySQL.execute(params.options, statement, 0, 0)
    }

    /** Executes the user-supplied statement with the requested fetch window and arguments. */
    @Throws(SQLException::class)
    fun execute(params: DbmRequestForm): Any {
        return MySQL.execute(params.options, params.statement, params.fetchOffset, params.fetchSize, params.args)
    }

    /**
     * Serves a file previously registered in the session by [export] or [exportDoc].
     *
     * The `filename` parameter is the session key, not a path. The entry is removed once it is read,
     * so each link can be used only once.
     *
     * @return the [File] when it exists, otherwise an error message string.
     */
    @Throws(SQLException::class)
    fun downloadExport(request: RequestContext): Any {
        val filename = request.params["filename"]

        // The key must be a plain token, never something that looks like a path.
        if (filename == null || filename.contains("\\") || filename.contains("/")) {
            return "非法请求"
        }

        val file = try {
            // An unknown key (presumably null from the session) is rejected here. Otherwise
            // null.toString() would resolve File("null") relative to the working directory.
            val fullname = request.session!!.get(filename)?.toString() ?: return "非法请求"
            request.session.remove(filename)
            File(fullname)
        } catch (e: Exception) {
            return "非法请求"
        }

        return if (file.exists()) file else "文件不存在"
    }

    /**
     * Exports Markdown documentation (one column table per table) for the current database.
     *
     * The generated file is registered in the session, and its key is returned in `qr.data`.
     * Failures are reported through `qr.success` / `qr.message` rather than thrown.
     */
    @Throws(SQLException::class)
    fun exportDoc(request: RequestContext, params: DbmRequestForm): Any {
        val qr = QueryResult()

        try {
            MySQL.connect(params.options) { connection: Connection ->
                val now = LocalDateTime.now()
                val dbname = connection.catalog
                val mdFile = "webdbm_export_${dbname}_doc_${now.format(FILE_TIMESTAMP)}.md"
                val mdTempFile = Paths.get(Misc.getTmpPath(), mdFile).toFile()

                // use {} guarantees the writer is closed even when a query below fails.
                mdTempFile.printWriter().use { writer ->
                    writer.println("<!-- ======================================================= -->")
                    writer.println("<!-- -- Database: $dbname -->")
                    writer.println("<!-- -- Date: $now -->")
                    writer.println("<!-- -- =======================================================-->")
                    writer.println()
                    writer.println()
                    writer.println("# $dbname")
                    writer.println()
                    writer.println("## Tables")
                    writer.println()

                    getTableList(connection).forEach { (tableName, tableComment) ->
                        writer.println("### $tableName")
                        writer.println()
                        writer.println("> " + tableComment.ifBlank { "-" })
                        writer.println()
                        formatColumnDoc(connection, tableName, writer)
                    }

                    writer.println()
                    writer.println("<!-- The End -->")
                }

                // TODO: also generate an Excel version and bundle both files into one archive.

                val fileid = UUID.randomUUID().toString()
                request.session!!.set(fileid, mdTempFile)

                qr.data = arrayListOf(fileid)
            }
            qr.success = true
        } catch (e: Exception) {
            qr.success = false
            qr.message = e.message
        }
        return qr
    }

    /** Writes the Markdown column table for [tableName] to [writer], followed by a blank line. */
    private fun formatColumnDoc(connection: Connection, tableName: String, writer: PrintWriter) {
        // Keys read from each column map, and the header shown for each.
        val keys = listOf("COLUMN_NAME", "COLUMN_TYPE", "IS_NULLABLE", "COLUMN_DEFAULT", "COLUMN_KEY", "COLUMN_COMMENT", "EXTRA")
        val headers = listOf("NAME", "TYPE", "NULLABLE", "DEFAULT", "KEY", "COMMENT", "EXTRA")

        val rows = getTableColumns(connection, tableName).map { column ->
            keys.map { key -> escapeMarkdownCell(column[key] ?: "") }
        }

        // Column width = widest header or cell, measured with the same function used for padding.
        val widths = headers.indices.map { i ->
            rows.fold(Misc.calculateUnicodeLength(headers[i])) { widest, row ->
                maxOf(widest, Misc.calculateUnicodeLength(row[i]))
            }
        }

        writer.println(createRow(headers, widths))
        writer.println(createSeparator(widths))
        rows.forEach { row -> writer.println(createRow(row, widths)) }
        writer.println()
    }

    /**
     * Renders one Markdown table row, padding each cell to its column width.
     *
     * Padding is based on [Misc.calculateUnicodeLength] rather than the char count, so it stays
     * consistent with how [widths] were measured.
     */
    private fun createRow(cells: List<String>, widths: List<Int>): String =
        cells.indices.joinToString(separator = "|", prefix = "|", postfix = "|") { i ->
            val padding = (widths[i] - Misc.calculateUnicodeLength(cells[i])).coerceAtLeast(0)
            " " + cells[i] + " ".repeat(padding) + " "
        }

    /** Renders the Markdown header separator row; each cell is its width plus the two padding spaces. */
    private fun createSeparator(widths: List<Int>): String =
        widths.joinToString(separator = "|", prefix = "|", postfix = "|") { width -> "-".repeat(width + 2) }

    /**
     * Dumps the comma-separated tables in `params.table` to a zipped `.sql` file.
     *
     * The zip is registered in the session, and its key is returned in `qr.data`.
     * Failures are reported through `qr.success` / `qr.message` rather than thrown.
     */
    @Throws(SQLException::class)
    fun export(request: RequestContext, params: DbmRequestForm): Any {
        // Skip empty entries (e.g. a trailing comma) instead of trying to dump a table named "".
        val tables = params.table.split(',').filter { it.isNotEmpty() }
        val qr = QueryResult()

        try {
            MySQL.connect(params.options) { connection: Connection ->
                val now = LocalDateTime.now()
                val dbname = connection.catalog
                val filename = "webdbm_export_${dbname}_${now.format(FILE_TIMESTAMP)}.sql"

                val tempFile = Paths.get(Misc.getTmpPath(), filename).toFile()
                val zipFile = tempFile.absolutePath + ".zip"

                try {
                    tempFile.printWriter().use { writer ->
                        writer.println("-- =======================================================")
                        writer.println("-- Database: $dbname")
                        writer.println("-- Date: $now")
                        writer.println("-- =======================================================")
                        writer.println()

                        writer.println("SET NAMES utf8mb4;")
                        writer.println("SET FOREIGN_KEY_CHECKS = 0;")
                        writer.flush()

                        tables.forEach { table ->
                            fetchTable(connection, table, writer)
                            writer.println()
                            writer.flush()
                        }

                        writer.println()
                        writer.println("SET FOREIGN_KEY_CHECKS = 1;")
                    }

                    Misc.zipFile(tempFile, zipFile)
                } finally {
                    // The plain .sql file is only an intermediate; remove it whether or not zipping succeeded.
                    tempFile.delete()
                }

                val fileid = UUID.randomUUID().toString()
                request.session!!.set(fileid, zipFile)

                qr.data = arrayListOf(fileid)
            }
            qr.success = true
        } catch (e: Exception) {
            qr.success = false
            qr.message = e.message
        }
        return qr
    }

    /**
     * Writes the structure of [table] to [writer]: a `DROP TABLE IF EXISTS` and the `show create table`
     * output. It then writes one `INSERT` statement per row.
     *
     * Cells that [MySQL.readCell] flags as needing quotes are written as escaped single-quoted
     * literals. All other cells are written verbatim.
     */
    private fun fetchTable(connection: Connection, table: String, writer: PrintWriter) {
        val quotedTable = quoteIdentifier(table)

        connection.createStatement().use { stmt ->
            writer.println("-- --------------------------")
            writer.println("-- Table structure for $table")
            writer.println("-- --------------------------")

            writer.println("DROP TABLE IF EXISTS $quotedTable;")

            val createSql = stmt.executeQuery("show create table $quotedTable").use { rs ->
                if (!rs.next()) {
                    throw SQLException("Unable to read the definition of table $table")
                }
                rs.getString(2)
            }
            writer.println("$createSql;")
            writer.println()

            writer.println("-- ----------------------------")
            writer.println("-- Records of $table")
            writer.println("-- ----------------------------")

            stmt.executeQuery("select * from $quotedTable").use { rs ->
                val meta = rs.metaData
                val colTypes = IntArray(meta.columnCount) { i -> meta.getColumnType(i + 1) }
                val colStr = (1..meta.columnCount).joinToString(", ") { i -> quoteIdentifier(meta.getColumnName(i)) }

                var count = 0
                while (rs.next()) {
                    val values = colTypes.indices.joinToString(", ") { i ->
                        val cell = MySQL.readCell(colTypes[i], rs, i)
                        if (cell.second) "'${escapeSqlString(cell.first)}'" else cell.first
                    }
                    writer.println("INSERT INTO $quotedTable($colStr) VALUES($values);")

                    // Flush every 100 rows so large tables are streamed to disk.
                    count++
                    if (count == 100) {
                        writer.flush()
                        count = 0
                    }
                }
            }
        }
    }

    /** Returns (TABLE_NAME, TABLE_COMMENT) pairs for the connection's current catalog; NULL comments become "". */
    private fun getTableList(connection: Connection): List<Pair<String, String>> =
        connection.prepareStatement(
            "select TABLE_NAME, TABLE_COMMENT from information_schema.`tables` where table_schema = ?"
        ).use { stmt ->
            stmt.setString(1, connection.catalog)
            stmt.executeQuery().use { rs ->
                val data = mutableListOf<Pair<String, String>>()
                while (rs.next()) {
                    data.add(rs.getString(1).orEmpty() to rs.getString(2).orEmpty())
                }
                data
            }
        }

    /**
     * Returns one map per column of [tableName] (in the current catalog), in ordinal order.
     *
     * The map is keyed by the information_schema column names. Empty key/extra/comment values become
     * "-", a NULL default becomes "-", and line breaks in comments become `<br/>`.
     */
    private fun getTableColumns(connection: Connection, tableName: String): List<HashMap<String, String>> =
        connection.prepareStatement(
            "select" +
                " COLUMN_NAME, IS_NULLABLE, COLUMN_DEFAULT, COLUMN_TYPE, COLUMN_KEY, EXTRA, COLUMN_COMMENT" +
                " from information_schema.columns" +
                " where table_schema = ? and table_name = ?" +
                " order by ORDINAL_POSITION"
        ).use { stmt ->
            stmt.setString(1, connection.catalog)
            stmt.setString(2, tableName)
            stmt.executeQuery().use { rs ->
                val data = mutableListOf<HashMap<String, String>>()
                while (rs.next()) {
                    val col = HashMap<String, String>()
                    col["COLUMN_NAME"] = rs.getString("COLUMN_NAME").orEmpty()
                    col["IS_NULLABLE"] = rs.getString("IS_NULLABLE").orEmpty()
                    col["COLUMN_DEFAULT"] = rs.getString("COLUMN_DEFAULT") ?: "-"
                    col["COLUMN_TYPE"] = rs.getString("COLUMN_TYPE").orEmpty()
                    col["COLUMN_KEY"] = rs.getString("COLUMN_KEY").orEmpty().ifBlank { "-" }
                    col["COLUMN_COMMENT"] = rs.getString("COLUMN_COMMENT").orEmpty()
                        .replace("\r\n", "\n")
                        .replace("\r", "\n")
                        .replace("\n", "<br/>")
                        .ifBlank { "-" }
                    col["EXTRA"] = rs.getString("EXTRA").orEmpty().ifBlank { "-" }
                    data.add(col)
                }
                data
            }
        }

    companion object {
        /** Schema overview queries; `#database#` is replaced by the escaped database name. */
        const val INFORMATION_STATEMENTS: String =
            "select TABLE_NAME, TABLE_COMMENT, ENGINE, ROW_FORMAT, TABLE_ROWS, TABLE_COLLATION, TABLE_TYPE from information_schema.`tables` where table_schema='#database#';\n" +
                    "select TABLE_NAME, VIEW_DEFINITION, DEFINER from information_schema.views where table_schema='#database#';\n" +
                    "select ROUTINE_NAME, ROUTINE_TYPE, ROUTINE_COMMENT, ROUTINE_DEFINITION, CREATED, LAST_ALTERED, DEFINER from information_schema.`routines` where routine_schema='#database#';\n" +
                    "select EVENT_NAME, EVENT_DEFINITION, EVENT_TYPE, INTERVAL_VALUE, INTERVAL_FIELD, EVENT_COMMENT, STARTS, ENDS, STATUS, CREATED, LAST_ALTERED, LAST_EXECUTED from information_schema.`events` where event_schema='#database#';\n" +
                    "select TRIGGER_NAME, EVENT_MANIPULATION, EVENT_OBJECT_TABLE, ACTION_ORDER, ACTION_CONDITION, ACTION_STATEMENT, ACTION_TIMING,CREATED, DEFINER from information_schema.`triggers` where trigger_schema='#database#';\n" +
                    "show variables;\n"

        /** Timestamp pattern used in generated export file names. */
        private val FILE_TIMESTAMP: DateTimeFormatter = DateTimeFormatter.ofPattern("yyyyMMdd_HHmmss")

        /**
         * Escapes [value] for use inside a single-quoted MySQL string literal (quotes not included).
         * The escaped characters match the ones mysql_real_escape_string handles: NUL, LF, CR,
         * backslash, both quote characters and Ctrl-Z.
         */
        private fun escapeSqlString(value: String): String = buildString(value.length + 8) {
            for (ch in value) {
                when (ch) {
                    '\u0000' -> append("\\0")
                    '\n' -> append("\\n")
                    '\r' -> append("\\r")
                    '\u001A' -> append("\\Z")
                    '\\' -> append("\\\\")
                    '\'' -> append("\\'")
                    '"' -> append("\\\"")
                    else -> append(ch)
                }
            }
        }

        /** Quotes [name] as a MySQL identifier, doubling any embedded backticks. */
        private fun quoteIdentifier(name: String): String = "`" + name.replace("`", "``") + "`"

        /** Escapes the Markdown table delimiter so cell content cannot split a row. */
        private fun escapeMarkdownCell(value: String): String = value.replace("|", "\\|")

        /**
         * Splits `database.table` at the first dot.
         * A name without a dot is paired with [defaultDatabase].
         */
        private fun splitQualifiedName(name: String, defaultDatabase: String): Pair<String, String> {
            val parts = name.split(".", limit = 2)
            return if (parts.size == 1) defaultDatabase to parts[0] else parts[0] to parts[1]
        }
    }
}